Support Flashinfer rope+quant+cache update fusion kernel for TRTLLM attention #36858
elvischenv wants to merge 4 commits into vllm-project:main
Conversation
Code Review
This PR introduces support for Flashinfer's fused RoPE, quantization, and KV cache update kernel, which is a great performance optimization for FP8 models on CUDA. The changes are well-structured, adding a new RopeQuantReshapeKVCachePattern to handle the fusion and updating related components to support it.
However, I've found a critical issue in vllm/v1/attention/backends/flashinfer.py where a check for KV cache sharing was removed, which could lead to incorrect behavior for models that use this feature. Please see my comment for details.
Force-pushed: 76992c4 to ed31eaa, ed31eaa to cb4d5e7
ProExpertProg left a comment:
I see now, the kernel requires attention metadata which is not built during PIECEWISE warmup/capture. We can keep it excluded for now but we should collect some perf numbers for this kernel inside/outside cudagraphs to see how much this hurts us. And we should only exclude it for FlashInfer
    @@ -205,13 +322,29 @@ def __init__(self, config: VllmConfig) -> None:
        self.max_token_num = cc.pass_config.rope_kvcache_fusion_max_token_num

        attn_layers = get_layers_from_vllm_config(config, Attention)
        for _, layer in attn_layers.items():
            if layer.impl.fused_rope_kvcache_supported():
                if current_platform.is_cuda():
Can we consolidate this:

    for layer in ...:
        if not layer.supported():
            continue
        for is_neox in [True, False]:
            if is_cuda():
                for use_flashinfer_rope in [True, False]:
                    RopeQuantReshapeKVCachePattern(...).register()
            if is_rocm():
                RopeReshapeKVCachePattern(...).register()
    @@ -1005,6 +1005,13 @@ def set_splitting_ops_for_v1(
        # list via reference.
        self.splitting_ops = list(self._attention_ops)

        # Like attn op, fuse_rope_kvcache op also needs to be a splitting op
attn metadata access does not matter here. What matters is whether the tensors and shapes are static; can we make them static so this doesn't need to be excluded from CG?
    @@ -83,7 +83,6 @@ def __init__(
        self.rotary_emb = get_rope(
            self.head_dim,
            max_position=config.max_position_embeddings,
Why is this needed?
dtype=torch.float32 mainly controls the dtype of cos_sin_cache. But in the runtime forward it will always be converted to the same dtype as the query by _match_cos_sin_cache_dtype, so dtype=torch.float32 has no effect other than delaying the conversion to runtime.
See vllm/vllm/model_executor/layers/rotary_embedding/base.py, lines 182 to 190 at d28d86e.
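To make the behavior described above concrete, here is a toy sketch with plain-Python stand-ins for tensors (the real code uses torch; all names here are illustrative only, not vLLM's actual implementation):

```python
class FakeTensor:
    """Minimal stand-in for a torch.Tensor carrying only a dtype tag."""

    def __init__(self, dtype):
        self.dtype = dtype

    def to(self, dtype):
        return FakeTensor(dtype)


class RopeSketch:
    def __init__(self):
        # The cos/sin cache is constructed in float32 at init time...
        self.cos_sin_cache = FakeTensor("float32")

    def _match_cos_sin_cache_dtype(self, query):
        # ...but cast to the query's dtype on first use, so building it
        # in float32 only delays the conversion to runtime.
        if self.cos_sin_cache.dtype != query.dtype:
            self.cos_sin_cache = self.cos_sin_cache.to(query.dtype)
```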
    @@ -1148,6 +1187,23 @@ def build(
            disable_split_kv=self.disable_split_kv,
        )
        attn_metadata.decode = FIDecode(wrapper=decode_wrapper)

        # Step 4: Pre-compute params for RoPE + FP8 quantize + KV cache update fusion
These look cudagraph-safe to me?
Yes, I have tested with different cudagraph modes and it currently works with cudagraph_mode=NONE/FULL_DECODE_ONLY/FULL_AND_PIECEWISE. Supporting FULL_AND_PIECEWISE requires the op to be excluded from the piecewise graph, since it needs to access attn_metadata.
        query_quant_scale: torch.Tensor | None = None,
        query_quant_out: torch.Tensor | None = None,
    ):
        if attn_metadata is None:
This means this would not work in piecewise cudagraphs?
        if attn_metadata is None:
            # Profiling run.
            return
This will prevent AITER rope-cache from being included in piecewise cudagraphs, which we definitely don't want.
    @@ -754,9 +754,9 @@ def fused_output_quant_supported(self, quant_key: "QuantKey"):
            """
            return False

    -    def fused_rope_kvcache_supported(self):
    +    def fused_rope_kvcache_supported(self, quant_key: "QuantKey | None" = None):
Nit: can you specify the quant is for query? Maybe call it query_quant_key?
Force-pushed: dd6afc1 to 89ffb62
Hi @elvischenv, the pre-commit checks have failed. Please run:

    uv pip install "pre-commit>=4.5.1"
    pre-commit install
    pre-commit run --all-files

Then, commit the changes and push to your branch.
Force-pushed: 89ffb62 to b069728
elvischenv left a comment:
Hi @ProExpertProg, I have resolved most of the above comments. Could you help review again? Thanks!
I see now, the kernel requires attention metadata which is not built during PIECEWISE warmup/capture. We can keep it excluded for now but we should collect some perf numbers for this kernel inside/outside cudagraphs to see how much this hurts us. And we should only exclude it for FlashInfer
Regarding benchmarking the kernel inside/outside cudagraphs, I am not sure what this means. This kernel needs to access attn_metadata so it cannot be added to the piecewise cudagraph. It is already included in the full decode cudagraph. Can you elaborate on this?
This pull request has merge conflicts that must be resolved before it can be merged.
    # Compute slot_mapping consistent with block_table:
    slots = []
    for i in range(batch_spec.batch_size):
        context_len = batch_spec.seq_lens[i] - batch_spec.query_lens[i]
        for j in range(batch_spec.query_lens[i]):
            global_pos = context_len + j
            physical_block = block_table_tensor[i, global_pos // block_size].item()
            slots.append(physical_block * block_size + global_pos % block_size)
    slot_mapping = torch.tensor(slots, dtype=torch.int64, device=device)
is this change a general fix or is it something required specifically for this PR?
This is a general fix for the baseline (unfused path).
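For reference, the slot-mapping arithmetic in the fix above reduces to the following index math (a pure-Python sketch with the batch-spec fields flattened into plain lists; no torch needed for the indexing itself):

```python
def compute_slot_mapping(seq_lens, query_lens, block_table, block_size):
    """Map each new token's logical position to a physical KV-cache slot."""
    slots = []
    for i in range(len(seq_lens)):
        # Tokens already present in the cache for sequence i.
        context_len = seq_lens[i] - query_lens[i]
        for j in range(query_lens[i]):
            global_pos = context_len + j
            # Look up the physical block for this logical block index,
            # then add the offset within the block.
            physical_block = block_table[i][global_pos // block_size]
            slots.append(physical_block * block_size + global_pos % block_size)
    return slots
```

For example, with block_size=2 and block_table=[[5, 7]], a sequence with seq_len=3 and query_len=1 writes its single new token (global position 2) into physical block 7 at offset 0, i.e. slot 14.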
@elvischenv can you fix the merge conflicts please? I also think some of the fusion failures are related.
Force-pushed: d13d4ea to 696f8f6
This pull request has merge conflicts that must be resolved before it can be merged.
Commits (all Signed-off-by: elvischenv <219235043+elvischenv@users.noreply.github.com>):
- update unit test
- resolve issue
- Apply suggestions from code review (Co-authored-by: Michael Goin <mgoin64@gmail.com>)
Force-pushed: 696f8f6 to 42731a1
@mgoin Fixed the conflicts. See vllm/vllm/compilation/passes/fusion/rope_kvcache_fusion.py, lines 270 to 274 at 0e39202. There is some hardcoding in conftest.py that may need fixes: vllm/tests/compile/fusions_e2e/conftest.py, lines 145 to 218 at 1f5ec28. @ProExpertProg could you look into this after this PR merges to main? Thanks.
Does this work without inductor partition? My understanding is that it won't work, because piecewise cudagraphs will simply skip the fused op since attention metadata is not set during piecewise capture.
@mgoin and I discussed this and we think a short-term fix could be to either set attention metadata during piecewise capture and make sure attention doesn't run, or just call the unfused kernel inside the fused op if metadata isn't set.
The proper long-term fix (proposed by @LucasWilkinson) would be to use static buffers and either access them through new metadata for kvcache update which includes the slot mapping, or just read them from the layer.
Could you try the long-term fix first?
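For concreteness, the short-term fallback idea above (call the unfused kernels inside the fused op when metadata isn't set) could look roughly like this; all function names below are hypothetical stand-ins, not vLLM's actual API:

```python
def unfused_rope(query, key):
    # Stand-in for the separate RoPE kernel.
    return query, key


def unfused_kv_cache_update(key, value):
    # Stand-in for the separate reshape-and-cache kernel.
    pass


def fused_kernel(query, key, value, attn_metadata):
    # Stand-in for the fused FlashInfer rope+quant+cache-update kernel.
    return query, key


def fused_rope_and_kv_cache_update(query, key, value, attn_metadata=None):
    if attn_metadata is None:
        # Piecewise cudagraph capture / profiling: metadata is not set, so
        # fall back to the unfused path instead of silently skipping the op.
        query, key = unfused_rope(query, key)
        unfused_kv_cache_update(key, value)
        return query, key
    return fused_kernel(query, key, value, attn_metadata)
```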
        dtype: torch.dtype,
        device: torch.device,
        prefix: str = "model.layers.0.self_attn.attn",
        attn_backend: AttentionBackendEnum = None,
Why is this ever None?
        view_to_reshape(gm)
        return gm

    pm.register_replacement(
In the case where we're using the layername wildcard, we should add an extra_check method that checks the fusion support for that layer.
Can we actually separate the closure and input ones into separate pattern/replacement classes? They can share a base
    fuse_rope_kvcache: bool = None  # type: ignore[assignment]
    """Fuse the QK rope + KV cache ops."""

    rope_kvcache_fusion_max_token_num: int = 256
Should we use the same threshold for this kernel? This was defaulted because the AITER kernel is slower than unfused above 256 tokens
        and self.use_inductor_graph_partition
        and self.pass_config.fuse_rope_kvcache
    ):
        self.splitting_ops.append(
This will work with inductor graph partition. Without it, fused_rope_and_unified_kv_cache_update will remain in the piecewise graph (necessary to perform fusion). But it won't be captured in piecewise cudagraphs because it will be skipped as attention metadata is not set
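A minimal sketch of the gating described above, under the assumption that the splitting-op list is built from a few config flags (helper name and signature are hypothetical; the real logic lives in vLLM's config code):

```python
def get_splitting_ops(
    use_inductor_graph_partition: bool,
    fuse_rope_kvcache: bool,
    attention_ops: list[str],
) -> list[str]:
    # Attention ops always split the graph. The fused rope+kvcache op is
    # appended only under inductor graph partition, which keeps it out of
    # piecewise cudagraphs (it needs attn_metadata at runtime).
    ops = list(attention_ops)
    if use_inductor_graph_partition and fuse_rope_kvcache:
        ops.append("fused_rope_and_unified_kv_cache_update")
    return ops
```

Without inductor graph partition, the op stays inside the piecewise graph (which is necessary for fusion to fire) but is skipped during piecewise capture, matching the behavior described in the comment above.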
        fuse_attn_quant=True,
        enable_qk_norm_rope_fusion=True,
        fuse_allreduce_rms=True,
        fuse_rope_kvcache=False,  # FIXME: disable to avoid compile range split
Instead of disabling the rope-cache fusion in tests, can we adjust the compile range logic?
To expand, I think we can do something like: ... though I'm not sure how to move the clamp off the hotpath. Edit: actually we might run into issues for ...
Actually I think the easiest would be to just move https://github.com/flashinfer-ai/flashinfer/blob/bf9b1dac855005ffaa57b48ae54cba30642bf213/include/flashinfer/pos_enc.cuh#L800-L1036 into vLLM and modify it to use a slot mapping (and support ...)
Purpose
Support the Flashinfer RoPE+Quant+KV Cache Update fusion kernel rope_quantize_fp8_append_paged_kv_cache.

Depends on flashinfer-ai/flashinfer#2792, which fixed the padding token issue for the kernel when using full cudagraph.
Test Plan and Test Result

Fusion pass unit test:

    pytest -v -s tests/compile/passes/test_rope_kvcache_fusion.py::test_rope_quant_kvcache_fusion

Model e2e accuracy
Server cmd:
Fused:
Unfused:
Model e2e perf
Fused: about 5% perf gain for GPT-OSS-120b TP8 con8
Unfused: